#ifndef __MATVEC__
#define __MATVEC__

#include "constants.h"
#include "OpenPFEM.h"
#include "petscmat.h"
#include "femfunction.h"

typedef struct CSR_MATRIX_
{
    INT NumRows;
    INT NumColumns;
    INT NumEntries;
    INT *RowPtr;
    INT *KCol;
    DOUBLE *Entries;
} CSRMAT;

typedef struct MATRIX_
{
    /* 矩阵设置 */
    DATATYPE type;
    MPI_Comm comm;
    void *ops;

    /* 矩阵性质 */
    INT global_nrows, global_ncols; // 全局行/列数
    INT local_nrows, local_ncols;   // 局部行/列数
    bool symmetric;                 // 对称与否
    bool delete_dirichletbd;        // 删除了delete边界与否

    /* 基本存储格式 */
    bool if_matrix;           // 是否有OpenPFEM格式的数据存在
    INT rows_start, rows_end; // 本进程行：[rows_start : rows_end-1]
    INT cols_start, cols_end; // 本进程列：[cols_start : cols_end-1]
    CSRMAT *diag;             // 对角部分
    CSRMAT *ndiag;            // 非对角部分
    INT *ndiag_globalcols;    // 非对角部分列编号对应的全局编号
    // 乘法操作需要的信息（单向量)
    MPI_Comm mult_comm, inverse_mult_comm;          // 矩阵乘和矩阵转置乘时分别需要的通信器
    INT mult_sendtotalnum, mult_recvtotalnum;       // 矩阵乘时所发送和收到的数据个数
    INT *mult_sendnum, *mult_recvnum;               // 需要发送给每个neig的个数
    INT *mult_sendnum_displs, *mult_recvnum_displs; // 上面变量的displs
    INT *mult_sendindex;                            // 发送的数据本地index
    DOUBLE *mult_sendtmp, *mult_recvtmp;            // 用以存储数据的空间（仅针对单向量大小）

    /* PETSc */
    bool if_petsc;    // 是否有Petsc格式的数据存在
    void *data_petsc; // 具体数据（格式Mat)

    /* assemble */
    void *assemble_info;
} MATRIX;

typedef struct VECTOR_
{
    // 向量设置
    DATATYPE type;
    MPI_Comm comm;
    void *ops;

    // 向量性质
    INT local_length;        // 局部长度
    INT global_length;       // 全局长度
    INT start, end;          // 含有全局编号[start:end-1]
    bool delete_dirichletbd; // 删除了delete边界与否

    // 基本存储格式
    bool if_vector; // 是否有OpenPFEM格式的数据存在
    DOUBLE *data;   // 具体数据

    // PETSc
    bool if_petsc;    // 是否有Petsc格式的数据存在
    void *data_petsc; // 具体数据（格式Vec)
} VECTOR;

typedef struct VECTORS_
{
    // 向量设置
    DATATYPE type;
    MPI_Comm comm;
    void *ops;

    // 向量性质
    INT nvecs;
    INT local_length;
    INT global_length;
    INT start, end;
    INT start_cols, end_cols;
    bool delete_dirichletbd; // 删除了delete边界与否

    // 基本存储格式
    bool if_vectors;
    DOUBLE *data;

    // SLEPc
    bool if_slepc;
    void *data_slepc;
} VECTORS;

typedef struct MATRIX_OPS_
{
    void (*create)(MATRIX *);
    void (*destroy)(MATRIX *);

    void (*sturcturecompare)(MATRIX *, MATRIX *, NONZERO_STRUCT *);
    void (*axpby)(DOUBLE, MATRIX *, DOUBLE, MATRIX *, NONZERO_STRUCT);
    void (*scale)(DOUBLE, MATRIX *);

    void (*vectorcreate)(VECTOR **, MATRIX *);
    void (*vectorcreate_T)(VECTOR **, MATRIX *);
    void (*matrixvectormult)(MATRIX *, VECTOR *, VECTOR *);
    void (*matrixvectormult_T)(MATRIX *, VECTOR *, VECTOR *);
    void (*matrixvectormult_add)(MATRIX *, VECTOR *, VECTOR *);
    void (*matrixvectormult_sub)(MATRIX *, VECTOR *, VECTOR *);

    void (*vectorscreate)(VECTORS **, MATRIX *, INT);
    void (*vectorscreate_T)(VECTORS **, MATRIX *, INT);
    void (*matrixvectorsmult)(MATRIX *, VECTORS *, VECTORS *);
    void (*matrixvectorsmult_T)(MATRIX *, VECTORS *, VECTORS *);
} MATRIX_OPS;

typedef struct VECTOR_OPS_
{
    void (*create)(VECTOR *);
    void (*destroy)(VECTOR *);
    void (*getarray)(VECTOR *, DOUBLE **);
    void (*scale)(DOUBLE, VECTOR *);
    void (*axpby)(DOUBLE, VECTOR *, DOUBLE, VECTOR *);
    void (*norm)(VECTOR *, DOUBLE *);
} VECTOR_OPS;

typedef struct VECTORS_OPS_
{
    void (*create)(VECTORS *);
    void (*destroy)(VECTORS *);
    void (*getarray)(VECTORS *, DOUBLE **);
    void (*restorearray)(VECTORS *, DOUBLE **);
} VECTORS_OPS;

#define CSRMatDestory(csrmat)                 \
    do                                        \
    {                                         \
        if ((csrmat) != NULL)                 \
        {                                     \
            OpenPFEM_Free((csrmat)->RowPtr);  \
            OpenPFEM_Free((csrmat)->KCol);    \
            OpenPFEM_Free((csrmat)->Entries); \
            OpenPFEM_Free((csrmat));          \
        }                                     \
    } while (0)

/////////////////////////////////////////////////////////////////////////////
/*                                 Matrix                                  */
/////////////////////////////////////////////////////////////////////////////
void MatrixCreate(MATRIX **Matrix, DATATYPE type);
void MatrixDestroy(MATRIX **matrix);
void MatrixCompress(MATRIX *matrix);
void MatrixPrint(MATRIX *matrix, INT printrank);
void MatrixPrint2Matlab(MATRIX *mat, INT printrank);
void MatrixOutput(MATRIX *matrix, const char *name);
void MatrixRowsReOrder(MATRIX **matrix, INT *old_gindex, INT rownum, BRIDGE *bridge);
// 把matrix的type及其内部数据转化为target_type，remain_current表示需不需要保留当前类型数据
void MatrixConvert(MATRIX *matrix, DATATYPE target_type, bool remain_current);
void MatrixOpsCreate(MATRIX *matrix, DATATYPE datatype);
void MatrixSturctureCompare(MATRIX *A, MATRIX *B, NONZERO_STRUCT *nonzero_struct);
void MatrixScale(DOUBLE alpha, MATRIX *A);
void MatrixAxpby(DOUBLE alpha, MATRIX *A, DOUBLE beta, MATRIX *B, NONZERO_STRUCT nonzero_struct);

/////////////////////////////////////////////////////////////////////////////
/*                                 Vector                                  */
/////////////////////////////////////////////////////////////////////////////
void VectorCreate(VECTOR **vector, DATATYPE type);
void VectorDestroy(VECTOR **vector);
void VectorPrint(VECTOR *vector, INT printrank);
void VectorOutput(VECTOR *vector, const char *name);
// 把vector的type及其内部数据转化为target_type，remain_current表示需不需要保留当前类型数据
void VectorConvert(VECTOR *vector, DATATYPE target_type, bool remain_current);
void VectorOpsCreate(VECTOR *vector, DATATYPE type);
void VectorGetArray(VECTOR *vector, DOUBLE **data);
void VectorScale(DOUBLE alpha, VECTOR *x);
void VectorAxpby(DOUBLE alpha, VECTOR *x, DOUBLE beta, VECTOR *y);
void VectorNorm(VECTOR *vector, DOUBLE *norm);
// 用于非零Dirichlet边界问题迭代求解 将右端项设置为初值
void SetRhsAsInitial(VECTOR *Rhs, VECTOR *Solution);

/////////////////////////////////////////////////////////////////////////////
/*                                Vectors                                  */
/////////////////////////////////////////////////////////////////////////////
void VectorsCreate(VECTORS **vectors, DATATYPE type);
void VectorsSetRange(VECTORS *vectors, INT start, INT end);
void VectorsDestroy(VECTORS **vectors);
void VectorsOpsCreate(VECTORS *vectors, DATATYPE type);
void VectorsGetArray(VECTORS *vectors, DOUBLE **data);
void VectorsRestoreArray(VECTORS *vectors, DOUBLE **data);
void VectorsAxpby(DOUBLE alpha, VECTORS *x, DOUBLE beta, VECTORS *y);

/////////////////////////////////////////////////////////////////////////////
/*                            Matrix - Vector                              */
/////////////////////////////////////////////////////////////////////////////
void VectorCreateByMatrix(VECTOR **vector, MATRIX *matrix);
void VectorCreateByMatrixTranspose(VECTOR **vector, MATRIX *matrix);
void MatrixVectorMult(MATRIX *A, VECTOR *x, VECTOR *y);
void MatrixTransposeVectorMult(MATRIX *A, VECTOR *x, VECTOR *y);
void MatrixVectorMultAdd(MATRIX *A, VECTOR *x, VECTOR *y); // y += Ax
void MatrixVectorMultSub(MATRIX *A, VECTOR *x, VECTOR *y); // y -= Ax

/////////////////////////////////////////////////////////////////////////////
/*                           Matrix - Vectors                              */
/////////////////////////////////////////////////////////////////////////////
void VectorsCreateByMatrix(VECTORS **vectors, MATRIX *matrix, INT num);
void VectorsCreateByMatrixTranspose(VECTORS **vectors, MATRIX *matrix, INT num);
void MatrixVectorsMult(MATRIX *A, VECTORS *x, VECTORS *y);
void MatrixTransposeVectorsMult(MATRIX *A, VECTORS *x, VECTORS *y);

/////////////////////////////////////////////////////////////////////////////
/*                           Vectors - Vector                              */
/////////////////////////////////////////////////////////////////////////////

/////////////////////////////////////////////////////////////////////////////
/*                          FemFunction(s) app                             */
/////////////////////////////////////////////////////////////////////////////

void VectorGetFEMFunction(VECTOR *vector, FEMFUNCTION *femfunc, INT index);
void VectorsGetFEMFunction(VECTORS *vectors, FEMFUNCTION *femfunc);
void FEMFunctionGetVector(FEMFUNCTION *femvec, VECTOR *vector);

#endif
